import json
from metric.perspective_api import PerspectiveApiScorer
from metric.detoxify import DetoxifyScorer
import argparse
from tqdm import tqdm
import time
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM

def parse_args(argv=None):
    """Build the command-line parser for the evaluation script and parse it.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
    """
    parser = argparse.ArgumentParser(description="Evaluation")
    parser.add_argument("--input", type=str, default="")
    parser.add_argument("--gpu-id", type=int, default=2, help="specify the gpu to load the model.")
    parser.add_argument("--output", type=str, default="")
    parser.add_argument("--api_key", type=str, default=".perspective_api_key")
    # One or more scorers may be requested in a single run.
    parser.add_argument(
        "--judge",
        type=str,
        nargs='+',
        choices=['detoxify', 'perspective_api', 'llama_Guard', 'GPT-4o'],
        default=['GPT-4o'],
    )

    ############# args for GPT-4 judge #############
    parser.add_argument(
        "--judge-model",
        default="gpt-4o-2024-08-06",
        # The help text previously claimed the default was the Llama Guard
        # model, which contradicted the actual default above.
        help="Name of judge model. Defaults to gpt-4o-2024-08-06.",
        choices=["gpt-3.5-turbo-1106", "gpt-4o-2024-08-06", "no-judge", "jailbreakbench", "gcg"],
    )
    parser.add_argument(
        "--judge-max-n-tokens",
        type=int,
        default=10,
        help="Maximum number of tokens for the judge.",
    )
    parser.add_argument(
        "--judge-temperature",
        type=float,
        default=0,
        help="Temperature to use for judge.",
    )
    return parser.parse_args(argv)

# Parse the command line once; the rest of the script reads these globals.
args = parse_args()

input_file = args.input
output_file = args.output


# When no output path is given, derive one from the input path by replacing
# the last ".jsonl" with "_eval.jsonl". Splitting from the right keeps any
# earlier ".jsonl" in the path (e.g. a directory name) intact; the previous
# left-to-right split cut the path at the first occurrence.
if output_file == "":
    output_file = input_file.rsplit('.jsonl', 1)[0] + '_eval.jsonl'

def float_it(score):
    """Convert every value of ``score`` to ``float`` in place.

    The same dict object is returned so the call can be used inline.
    """
    # Only existing keys are reassigned, so the dict size never changes
    # while it is being iterated.
    for name, raw_value in score.items():
        score[name] = float(raw_value)
    return score

def moderate(chat, tokenizer, model, device):
    """Generate a moderation verdict for ``chat`` and return it as text.

    Args:
        chat: Conversation passed to ``tokenizer.apply_chat_template``.
        tokenizer: Object providing ``apply_chat_template`` and ``decode``.
        model: Object providing ``generate``.
        device: Device the encoded prompt is moved to before generation.

    Returns:
        The decoded tokens that follow the prompt in the first generated
        sequence, with special tokens skipped.
    """
    prompt_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
    generated = model.generate(
        input_ids=prompt_ids,
        max_new_tokens=100,
        pad_token_id=0,
    )
    # Drop the leading prompt-length positions so only the tokens after the
    # prompt are decoded.
    completion_ids = generated[0][prompt_ids.shape[-1]:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)


# Instantiate only the scorers requested via --judge. Each section defines
# module-level names that the scoring loop below uses under the same
# condition.
if 'detoxify' in args.judge:
    # Detoxify
    detoxify_scorer = DetoxifyScorer()
    print('>>> Detoxify Scorer loaded <<<')

if 'perspective_api' in args.judge:
    # Perspective API: --api_key is the path of a file holding the key.
    # Read it inside a context manager so the handle is closed promptly.
    with open(args.api_key) as key_file:
        api_key = key_file.read().strip()
    perspective_scorer = PerspectiveApiScorer(api_key)
    print('>>> Perspective API Scorer loaded <<<')

if 'llama_Guard' in args.judge:
    # Llama-Guard
    model_id = "meta-llama/Llama-Guard-3-8B"
    # NOTE(review): --gpu-id is not consulted here; the model is always
    # placed on "cuda". Confirm whether that is intended.
    device = "cuda"
    dtype = torch.bfloat16
    # NOTE(review): jsonl_path is not read anywhere in the visible script;
    # kept only so the module-level name still exists.
    jsonl_path = "0_output/llava_ones_AAAI/output_250416_exp_baseline/class3_iter5000.jsonl"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=device)
    print('>>> Llama-Guard Scorer loaded <<<')

if 'GPT-4o' in args.judge:
    # GPT-4: imported lazily so the judge package is only needed when this
    # scorer is selected.
    from GPT_4_judge.judges import load_judge
    # The whole namespace is handed over; presumably load_judge reads the
    # --judge-* options from it (verify against GPT_4_judge.judges).
    judgeLM = load_judge(args)
    print('>>> GPT-4 Scorer loaded <<<')

def _first_if_list(value):
    """Return the first element of a list, ``""`` for an empty list, else ``value``."""
    if isinstance(value, list):
        return value[0] if value else ""
    return value


# Placeholder results written when a continuation is empty, so every scored
# record still carries one entry per requested judge.
# NOTE(review): 101 looks like a sentinel for "not judged"; confirm that it
# lies outside the judge's real score range.
_EMPTY_METRICS = {
    'perspective_api': {
        "identity_attack": 0.0,
        "profanity": 0.0,
        "severe_toxicity": 0.0,
        "sexually_explicit": 0.0,
        "threat": 0.0,
        "toxicity": 0.0,
    },
    'detoxify': {
        "identity_attack": 0.0,
        "obscene": 0.0,
        "severe_toxicity": 0.0,
        "insult": 0.0,
        "threat": 0.0,
        "toxicity": 0.0,
    },
    'llama_Guard': "\n\nsafe",
    'GPT-4o': {
        "score": 101,
        "raw_output": "empty string",
    },
}

# Read the whole input before the output is opened for writing, so the file
# handle is closed and the records are already in memory. Blank lines are
# skipped so an empty file or stray newline does not crash json.loads.
with open(input_file) as src:
    inputs = [line for line in src.read().split('\n') if line.strip()]

with open(output_file, 'w') as f:
    for li in tqdm(inputs):
        obj = json.loads(li)

        # Records without a continuation are copied through unchanged.
        if 'continuation' in obj:
            # Either field may be stored as a list; only its first element
            # is scored.
            text = _first_if_list(obj['continuation'])
            prompt = _first_if_list(obj['prompt'])

            if text != "":
                res = {}

                if 'detoxify' in args.judge:
                    # Detoxify
                    res['detoxify'] = float_it(detoxify_scorer.get_scores(text))

                if 'perspective_api' in args.judge:
                    # Perspective API
                    res['perspective_api'] = float_it(perspective_scorer.get_scores(text))
                    # Pause between calls; presumably to stay under the
                    # API's request rate limit.
                    time.sleep(1.1)

                if 'llama_Guard' in args.judge:
                    # Llama-Guard: build the chat from the unwrapped strings
                    # so a list-valued field is never passed as message
                    # content.
                    chat = [
                        {"role": "user", "content": prompt},
                        {"role": "assistant", "content": text}
                    ]
                    res['llama_Guard'] = moderate(chat, tokenizer, model, device)

                if 'GPT-4o' in args.judge:
                    # GPT-4: the judge takes batches; send a batch of one.
                    judge_scores, raw = judgeLM.score([prompt], [text])
                    res['GPT-4o'] = {"score": judge_scores[0], "raw_output": raw[0]}

                obj['continuation_metric'] = res
            else:
                # Nothing to score: record the fixed placeholder per judge.
                obj['continuation_metric'] = {
                    judge: _EMPTY_METRICS[judge]
                    for judge in args.judge
                    if judge in _EMPTY_METRICS
                }

        f.write(json.dumps(obj))
        f.write('\n')

